library(data.table)
library(mefa)
library(doParallel)
library(Hmisc)


setPars <- function(A = NULL,B = NULL, S = NULL,moment_weights = NULL) {
  # Build the parameter list consumed by the demand-side estimation code.
  #
  # Arguments (all optional; NULL selects the built-in default):
  #   A              -- 4 x 1 matrix, rows alpha/beta/gamma/log_f, column 'A'
  #   B              -- 4 x 2 matrix, columns log_income / log_house_price
  #   S              -- 4 x 4 matrix (diagonal by default)
  #   moment_weights -- weighting-scheme label, default 'inverse_mean'
  #
  # Returns a named list: simulation settings, legal constants, A/B/S,
  # moment_weights, diagnostic switches and the optimiser's initial_guess.
  # SEED is taken from the clock, so every call gets a different seed.

  par.names <- c('alpha','beta','gamma','log_f')

  if (is.null(A)) {
    A <- matrix(c(1.14, 5.79, 6.91, 12.22),
                nrow = 4, ncol = 1,
                dimnames = list(par.names, 'A'))
  }
  if (is.null(B)) {
    B <- matrix(c( 0.64, -1.87, -3.81, 0.38,    # log_income column
                  -0.57,  0.73,  2.30, 0.39),   # log_house_price column
                nrow = 4, ncol = 2,
                dimnames = list(par.names, c('log_income','log_house_price')))
  }
  if (is.null(S)) {
    S <- diag(c(0.07, 1.36, 0.05, 0.49))
    dimnames(S) <- list(par.names, par.names)
  }
  if (is.null(moment_weights)) {
    moment_weights <- 'inverse_mean'
  }

  list(
    ## Simulation parameters
    N_SIM          = 1000,
    BLP_TOL        = 1e-5,
    BLP_ATTEMPTS   = 100,
    SEED           = as.numeric(Sys.time()),
    CORES          = 6,
    EST_YEARS      = 2010:2015,
    ## Legal parameters
    LTV_LIMIT      = 0.90,
    BUNCHING_DELTA = 0.005,
    ## Structural parameters and moment weighting
    A              = A,
    B              = B,
    S              = S,
    moment_weights = moment_weights,
    ## Diagnostic output switches
    diag.show_moments     = TRUE,
    diag.show_linear      = TRUE,
    diag.show_parameters  = TRUE,
    diag.show_convergence = TRUE,
    ## Optimiser starting point, laid out as param_vec_to_matrix() expects
    initial_guess = c(5.79, 12.22,                # mu
                      0.64, -1.87, -3.81, 0.38,   # pi_income
                      -0.57, 0.73, 2.30, 0.39,    # pi_price
                      0.07, 1.36, 0.05, 0.49)     # sigma
  )
}


calculateMoments <- function(data,linear.model) {
  # Build the table of model-vs-data moments targeted by the estimation.
  #
  # Arguments:
  #   data         -- data.table of solved markets (presumably one row per product
  #                   with market-level columns repeated -- TODO confirm at caller);
  #                   must hold the *_MODEL columns and the observed market.* columns
  #   linear.model -- fitted linear demand model; predict(linear.model, data) is
  #                   the linear part of delta
  #
  # Returns a data.table with columns moment / model / data, one row per moment.
  # 'data' is 0 for moments defined as a model-side variance or covariance.
  #
  # Side effect: adds/overwrites RATE_INST, xi and bunching.bucket in `data` by
  # reference.

  # Moments involving the residual of the regression.
  # xi is evaluated at the observed ("real") rates, so RATE_INST is reset to RATE.
  data[,RATE_INST := RATE]
  data[,xi := delta - predict(linear.model,data)]
  moments.residuals <- rbind(
    data.table(moment = 'xi_var',model = var(data$xi,na.rm = TRUE),data = 0),
    data.table(moment = 'xi_income_covar',model = cov(data$xi,data$market.log_income.mean,use = 'pairwise.complete.obs'),data = 0),
    data.table(moment = 'xi_price_covar' ,model = cov(data$xi,data$market.log_price.mean,use = 'pairwise.complete.obs'),data = 0)
  )

  # Moments involving bunching.
  # keyby (not by) orders the buckets by factor level, i.e. by increasing
  # BUNCHING_MARKET_MODEL, so the Q1..Q4 labels match the actual quartiles.
  # With `by`, groups come out in order of first appearance in `data`.
  # Rows without a bucket are left out, so they cannot take a quartile label.
  data[,bunching.bucket := cut2(BUNCHING_MARKET_MODEL,g=4,levels.mean = TRUE)]
  bunching.quantile <- data[!is.na(bunching.bucket),
                            j=list(model = mean(BUNCHING_MARKET_MODEL,na.rm=TRUE),data = mean(market.bunching,na.rm=TRUE)),
                            keyby=bunching.bucket]
  bunching.quantile[,Quantile := paste0('Q',seq_len(.N))]

  moments.bunching <- rbind(
    data.table(moment = 'bunching_mean',  data = mean(data$market.bunching,na.rm=TRUE),model = mean(data$BUNCHING_MARKET_MODEL,na.rm=TRUE)),
    data.table(moment = 'just_above_mean',data = mean(data$market.just.above,na.rm=TRUE),model = mean(data$JUST_ABOVE_MODEL,na.rm=TRUE)),
    data.table(moment = 'just_below_mean',data = mean(data$market.just.below,na.rm=TRUE),model = mean(data$JUST_BELOW_MODEL,na.rm=TRUE)),
    data.table(moment = 'bunching_distance',model = var(data$market.bunching - data$BUNCHING_MARKET_MODEL,na.rm=TRUE),data = 0),
    data.table(moment = 'bunching_income_relationship', data = cov(data$market.bunching,data$market.log_income.mean,use = 'pairwise.complete.obs'), model = cov(data$BUNCHING_MARKET_MODEL,data$market.log_income.mean,use = 'pairwise.complete.obs')),
    data.table(moment = 'bunching_price_relationship' , data = cov(data$market.bunching,data$market.log_price.mean,use = 'pairwise.complete.obs'),  model = cov(data$BUNCHING_MARKET_MODEL,data$market.log_price.mean,use = 'pairwise.complete.obs')),
    data.table(moment = 'bunching_q1', model = bunching.quantile[Quantile == 'Q1']$model, data = bunching.quantile[Quantile == 'Q1']$data),
    data.table(moment = 'bunching_q2', model = bunching.quantile[Quantile == 'Q2']$model, data = bunching.quantile[Quantile == 'Q2']$data),
    data.table(moment = 'bunching_q3', model = bunching.quantile[Quantile == 'Q3']$model, data = bunching.quantile[Quantile == 'Q3']$data),
    data.table(moment = 'bunching_q4', model = bunching.quantile[Quantile == 'Q4']$model, data = bunching.quantile[Quantile == 'Q4']$data)
  )

  # Moments involving income differences (markets with observed bunching only)
  income.data <- data[market.bunching > 0]
  moments.income.bunching <- rbind(
    data.table(moment = 'income_bunching_mean',    data = mean(income.data$market.bunching.income.delta,na.rm=TRUE), model = mean(income.data$INCOME_DELTA_MARKET_MODEL,na.rm=TRUE)),
    data.table(moment = 'income_bunching_distance',model = var(income.data$market.bunching.income.delta - income.data$INCOME_DELTA_MARKET_MODEL,na.rm=TRUE),data = 0),
    data.table(moment = 'income_bunching_income_relationship', data = cov(income.data$market.bunching.income.delta,income.data$market.log_income.mean,use = 'pairwise.complete.obs'), model = cov(income.data$INCOME_DELTA_MARKET_MODEL,income.data$market.log_income.mean,use = 'pairwise.complete.obs')),
    data.table(moment = 'income_bunching_price_relationship' , data = cov(income.data$market.bunching.income.delta,income.data$market.log_price.mean,use = 'pairwise.complete.obs'), model = cov(income.data$INCOME_DELTA_MARKET_MODEL,income.data$market.log_price.mean,use = 'pairwise.complete.obs'))
  )

  # Moments involving loan sizes.
  # The *_price_relationship moments compare the covariance with log house price
  # on both sides; the model side previously used log income by mistake.
  moments.mean.loan.sizes <- rbind(
    data.table(moment = 'mean_loan_size_mean',         model = mean(data$LOG_LOAN_SIZE_MEAN_MARKET_MODEL,na.rm=TRUE), data = mean(data$market.log_loan_size.mean,na.rm=TRUE)),
    data.table(moment = 'mean_loan_size_distance',     model = var(data$LOG_LOAN_SIZE_MEAN_MARKET_MODEL - data$market.log_loan_size.mean,na.rm=TRUE),data = 0),
    data.table(moment = 'mean_loan_size_income_relationship', data = cov(data$market.log_loan_size.mean,data$market.log_income.mean,use = 'pairwise.complete.obs'), model = cov(data$LOG_LOAN_SIZE_MEAN_MARKET_MODEL,data$market.log_income.mean,use = 'pairwise.complete.obs')),
    data.table(moment = 'mean_loan_size_price_relationship' , data = cov(data$market.log_loan_size.mean,data$market.log_price.mean,use = 'pairwise.complete.obs'),  model = cov(data$LOG_LOAN_SIZE_MEAN_MARKET_MODEL,data$market.log_price.mean,use = 'pairwise.complete.obs'))
  )

  moments.std.loan.sizes <- rbind(
    data.table(moment = 'std_loan_size_mean',         model = mean(data$LOG_LOAN_SIZE_STD_MARKET_MODEL,na.rm=TRUE), data = mean(data$market.log_loan_size.std,na.rm=TRUE)),
    data.table(moment = 'std_loan_size_distance',     model = var(data$LOG_LOAN_SIZE_STD_MARKET_MODEL - data$market.log_loan_size.std,na.rm=TRUE),data = 0),
    data.table(moment = 'std_loan_size_income_relationship', data = cov(data$market.log_loan_size.std,data$market.log_income.mean,use = 'pairwise.complete.obs'), model = cov(data$LOG_LOAN_SIZE_STD_MARKET_MODEL,data$market.log_income.mean,use = 'pairwise.complete.obs')),
    data.table(moment = 'std_loan_size_price_relationship' , data = cov(data$market.log_loan_size.std,data$market.log_price.mean,use = 'pairwise.complete.obs') , model = cov(data$LOG_LOAN_SIZE_STD_MARKET_MODEL,data$market.log_price.mean,use = 'pairwise.complete.obs'))
  )

  # Moments about the estimation itself.
  # convergence.pct is the share of rows whose market converged, targeted at 1.
  # Previously the model side counted NON-converged rows against that target of 1.
  moments.diagnostics <- rbind(
    data.table(moment = 'convergence.pct', model = sum(data$converged == 1) / nrow(data),data = 1)
  )

  moments <- rbind(moments.residuals,moments.bunching,moments.income.bunching,moments.mean.loan.sizes,moments.std.loan.sizes,moments.diagnostics)

  return(moments)
}

param_vec_to_matrix <- function(x) {
  # Unpack the optimiser's flat 14-element parameter vector into the A / B / S
  # matrices used when transforming simulation draws.
  #
  # Layout of x:
  #   x[1:2]   -- means of beta and log_f (the alpha and gamma entries of A are fixed at 0)
  #   x[3:6]   -- loadings of (alpha, beta, gamma, log_f) on log_income
  #   x[7:10]  -- loadings of (alpha, beta, gamma, log_f) on log_house_price
  #   x[11:14] -- diagonal of S for (alpha, beta, gamma, log_f); off-diagonal entries are 0
  #
  # Returns list(A = 4x1, B = 4x2, S = 4x4) with named rows and columns.
  par.names <- c('alpha','beta','gamma','log_f')

  mu <- matrix(0, nrow = 4, ncol = 1, dimnames = list(par.names, 'A'))
  mu[c('beta','log_f'), 'A'] <- x[1:2]

  loadings <- matrix(x[3:10], nrow = 4, ncol = 2,
                     dimnames = list(par.names, c('log_income','log_house_price')))

  sigma <- matrix(0, nrow = 4, ncol = 4, dimnames = list(par.names, par.names))
  diag(sigma) <- x[11:14]

  list(A = mu, B = loadings, S = sigma)
}


getAllSharesAndMR <- function(markets.data,raw.draws,pars,linear.model,cores = 1) {
  # Compute counterfactual shares, markups and consumer surplus for every market.
  #
  # Arguments:
  #   markets.data -- data.table of products with a MARKET_ID column
  #   raw.draws    -- output of createRawDraws()
  #   pars         -- parameter list (see setPars())
  #   linear.model -- fitted linear demand model, forwarded to
  #                   calculateMarginalRevenuesAndCounterfactualShares()
  #   cores        -- 1 runs serially and prints each MARKET_ID; > 1 registers a
  #                   doParallel backend and uses foreach/%dopar%
  #
  # Returns the row-bound per-market results of
  # calculateMarginalRevenuesAndCounterfactualShares().

  all.markets <- unique(markets.data$MARKET_ID)
  if (length(all.markets) == 0) {
    stop('getAllSharesAndMR: markets.data has no MARKET_ID values')
  }

  # Work for a single market; shared by the serial and the parallel branch.
  oneMarket <- function(current.market.id) {
    current.market.data <- markets.data[MARKET_ID == current.market.id]
    transformed.draws <- transformRawDraws(raw.draws = raw.draws,market.data = current.market.data,pars = pars)
    calculateMarginalRevenuesAndCounterfactualShares(market.data = current.market.data,transformed.draws = transformed.draws,pars = pars,linear.model = linear.model)
  }

  if (cores == 1) {
    # Collect per-market pieces and bind once. Growing the result with rbind()
    # inside the loop copied the accumulated table on every iteration.
    pieces <- lapply(seq_along(all.markets), function(ii) {
      print(all.markets[ii])
      oneMarket(all.markets[ii])
    })
    shares_mr <- do.call(rbind, pieces)
  } else {
    registerDoParallel(cores = cores)
    shares_mr <- foreach(ii = seq_along(all.markets), .combine = rbind, .multicombine = TRUE) %dopar% {
      oneMarket(all.markets[ii])
    }
  }

  return(shares_mr)
}


solveAllMarkets <- function(markets.data,raw.draws,pars,delta.init = NULL,cores = 1) {
  # Run the share inversion (solveForDelta) for every market.
  #
  # Arguments:
  #   markets.data -- data.table of products with a MARKET_ID column
  #   raw.draws    -- output of createRawDraws()
  #   pars         -- parameter list (see setPars())
  #   delta.init   -- optional data.table with MARKET_ID, j, delta used as a warm start
  #   cores        -- 1 runs serially and prints each MARKET_ID; > 1 registers a
  #                   doParallel backend and uses foreach/%dopar%
  #
  # Returns the row-bound solveForDelta() results plus helper columns
  # max.delta / min.delta. Every row of a market whose delta leaves [-30, 30]
  # is flagged converged = 0.

  all.markets <- unique(markets.data$MARKET_ID)
  if (length(all.markets) == 0) {
    stop('solveAllMarkets: markets.data has no MARKET_ID values')
  }

  # Solve one market, warm-starting from delta.init when it is supplied.
  solveOneMarket <- function(current.market.id) {
    current.market.data <- markets.data[MARKET_ID == current.market.id]
    current.delta.init <- NULL
    if (!is.null(delta.init)) {
      current.delta.init <- delta.init[MARKET_ID == current.market.id,c('j','delta')]
    }
    solveForDelta(current.market.data,raw.draws,pars,current.delta.init)
  }

  if (cores == 1) {
    # Collect per-market pieces and bind once. Growing the result with rbind()
    # inside the loop copied the accumulated table on every iteration.
    pieces <- lapply(seq_along(all.markets), function(ii) {
      print(all.markets[ii])
      solveOneMarket(all.markets[ii])
    })
    solved.model <- do.call(rbind, pieces)
  } else {
    registerDoParallel(cores = cores)
    solved.model <- foreach(ii = seq_along(all.markets), .combine = rbind, .multicombine = TRUE) %dopar% {
      solveOneMarket(all.markets[ii])
    }
  }

  # Treat markets whose deltas ran away as not converged
  solved.model[,max.delta := max(delta),by='MARKET_ID']
  solved.model[,min.delta := min(delta),by='MARKET_ID']
  solved.model[max.delta > 30 | min.delta < -30,converged := 0]

  return(solved.model)
}



solveForDelta <- function(market.data,raw.draws,pars,delta.init = NULL) {
  # Invert one market's observed shares for the product mean utilities (delta),
  # iterating delta <- delta + log(SHARE_DATA) - log(SHARE_MODEL).
  #
  # Arguments:
  #   market.data -- product rows of a single market (uses j and MARKET_SHARE_DATA)
  #   raw.draws   -- output of createRawDraws()
  #   pars        -- parameter list (BLP_TOL, BLP_ATTEMPTS, plus what the helpers read)
  #   delta.init  -- optional data.table (j, delta) giving the starting point
  #
  # Returns the full calculateModeledMarketShares() output at the final delta
  # (SHARE_DATA dropped), plus `converged` (1/0) and `attempts`.

  if (is.null(delta.init)) {
    delta.init <- market.data[,c('j')]
    # Start all products at the same delta, picked to roughly hit the overall
    # inside share.
    inside.share <- sum(market.data$MARKET_SHARE_DATA)
    delta.init$delta <- log(1/4 * (inside.share)/(1-inside.share))
  }

  # Transform draws to market data
  transformed.draws <- transformRawDraws(raw.draws = raw.draws,market.data = market.data,pars = pars)

  # Contraction loop. It stops on convergence, on the attempt cap, or as soon
  # as the error is no longer finite. A modeled share of 0 gives
  # log(SHARE_MODEL) = -Inf, and the next pass produces NaN. A NaN error would
  # make the while() condition NA and abort with an error instead of reporting
  # non-convergence.
  error    <- 1
  attempts <- 1
  while (is.finite(error) && error > pars$BLP_TOL && attempts < pars$BLP_ATTEMPTS) {
    with.shares <- calculateModeledMarketShares(delta.init,market.data,transformed.draws,pars,shares.only = TRUE)
    with.shares[,deviation := log(SHARE_DATA) - log(SHARE_MODEL)]
    with.shares[,delta := delta + deviation]

    delta.init <- with.shares[,c('j','delta')]
    error      <- sum(with.shares$deviation^2)
    attempts   <- attempts + 1
  }
  # A non-finite error (diverged iteration) counts as not converged.
  converged <- if (is.finite(error) && error < pars$BLP_TOL) 1 else 0

  # Calculate all the moments at the final delta and attach convergence info
  solution <- calculateModeledMarketShares(delta.init,market.data,transformed.draws,pars,shares.only = FALSE)[,!c('SHARE_DATA')]
  solution$converged <- converged
  solution$attempts <- attempts

  return(solution)
}


calculateModeledMarketShares <- function(delta,market.data,transformed.draws,pars,shares.only = F) {
  # Simulate one market's product shares for given mean utilities and,
  # optionally, the market-level micro moments.
  #
  # Arguments:
  #   delta             -- data.table with columns (j, delta); merged onto market.data by j
  #   market.data       -- product rows of a single market
  #   transformed.draws -- simulated consumers, as returned by transformRawDraws()
  #   pars              -- parameter list (LTV_LIMIT, BUNCHING_DELTA; N_SIM via appendDrawsToData)
  #   shares.only       -- TRUE: return only the product-level table
  #
  # Returns:
  #   shares.only = TRUE  -> byProduct with columns j, delta, SHARE_MODEL, SHARE_DATA
  #   shares.only = FALSE -> the single-row byMarket table (MARKET_ID and the *_MODEL
  #                          micro moments) cbind-ed with byProduct, so the market
  #                          values repeat on every product row
  
  ## Append delta to the market data. Expects delta as a data.table with columns (j,delta)
  market.data.with.delta <- merge(market.data,delta,by=c('j'))
  
  ## Expand market data with draws: one row per (product, simulated consumer i)
  md <- appendDrawsToData(market.data.with.delta,transformed.draws,pars)
  
  ## Calculate utilities and market shares
  # Loan cap: conforming ('c') products are bounded by both the LTV limit and the
  # conforming loan limit; jumbo ('j') products only by the LTV limit.
  md[,constrained.loan.size := pmin(pars$LTV_LIMIT * exp(log_house_price),market.conforming_loan_limit) * (JUMBO == 'c') + pars$LTV_LIMIT * exp(log_house_price) * (JUMBO == 'j') ] # Maximum loan size the consumer can get with this product
  md[,actual.loan.size      := pmin(exp(log_f),constrained.loan.size)]
  # Per-lender utility: delta plus the consumer's rate and conforming-product deviations,
  # minus beta when the desired loan exp(log_f) exceeds the cap, minus 999 for a
  # jumbo product whose actual loan is below the conforming limit.
  md[,numerator.per.lender  := exp(delta + alpha.deviation * RATE_INST + gamma.deviation * (JUMBO == 'c') - beta * (exp(log_f) > constrained.loan.size) - 999 * (actual.loan.size < market.conforming_loan_limit) * (JUMBO == 'j'))] # the -999 penalty means a jumbo product that is not actually needed is effectively never chosen
  md[,numerator             := N_LENDERS * numerator.per.lender ]
  # The sum runs over consumer i's product rows, so the outside option adds 1 per
  # product row (not a single 1) to the denominator.
  md[,denominator           := sum(1 + numerator,na.rm=T),by='i']
  md[,share                 := numerator / denominator]
  
  
  # Aggregate back to product level; SHARE_MODEL is the average share over simulated consumers
  byProduct <- md[,j=list(delta = delta[1],SHARE_MODEL = mean(share,na.rm=T),SHARE_DATA = MARKET_SHARE_DATA[1]),by=c('j')]
  if(shares.only) {
    toReturn <- byProduct
  } else {
    
    ## Calculate micro moments---income deltas, bunching
    # is.bunching:      loan / conforming limit strictly inside (1 - BUNCHING_DELTA, 1 + BUNCHING_DELTA)
    # is.bunching.ring: ratio in (0.9, 1 - BUNCHING_DELTA) or (1 + BUNCHING_DELTA, 1.1)
    # just.below / just.above: loan in (0.95 x limit, limit] / (limit, 1.05 x limit]
    md[,is.bunching           := as.numeric(actual.loan.size / market.conforming_loan_limit > 1-pars$BUNCHING_DELTA  & actual.loan.size / market.conforming_loan_limit < 1+pars$BUNCHING_DELTA )]
    md[,is.bunching.ring      := as.numeric(  (actual.loan.size / market.conforming_loan_limit > .9 & actual.loan.size / market.conforming_loan_limit < 1-pars$BUNCHING_DELTA) | (actual.loan.size / market.conforming_loan_limit > 1+pars$BUNCHING_DELTA & actual.loan.size / market.conforming_loan_limit < 1.1) )]
    md[,just.below            := as.integer(actual.loan.size > market.conforming_loan_limit * .95 & actual.loan.size <= market.conforming_loan_limit)]
    md[,just.above            := as.integer(actual.loan.size > market.conforming_loan_limit       & actual.loan.size <= market.conforming_loan_limit * 1.05)]
    
    # Market micro moments: share-weighted over all (product, consumer) rows; no `by`, so one row.
    # INCOME_DELTA_MARKET_MODEL = mean income of bunchers minus mean income in the ring.
    byMarket  <- md[,j=list(MARKET_ID = MARKET_ID[1],
                            BUNCHING_MARKET_MODEL = weighted.mean(is.bunching,w=share,na.rm=T),
                            INCOME_DELTA_MARKET_MODEL = weighted.mean(exp(log_income),w=is.bunching*share,na.rm=T) - weighted.mean(exp(log_income),w=is.bunching.ring * share,na.rm=T),
                            LOG_LOAN_SIZE_MEAN_MARKET_MODEL = weighted.mean(log(actual.loan.size),w=share,na.rm=T), 
                            LOG_LOAN_SIZE_STD_MARKET_MODEL = weighted.std(log(actual.loan.size),w=share),
                            JUST_BELOW_MODEL = weighted.mean(just.below,w=share,na.rm=T),
                            JUST_ABOVE_MODEL = weighted.mean(just.above,w=share,na.rm=T))]
    # All-zero weights yield NaN (e.g. nobody bunches); report those moments as 0
    byMarket[is.na(BUNCHING_MARKET_MODEL),BUNCHING_MARKET_MODEL := 0]
    byMarket[is.na(INCOME_DELTA_MARKET_MODEL),INCOME_DELTA_MARKET_MODEL := 0]
    
    # To return: the one-row byMarket is recycled across the product rows
    toReturn <- cbind(byMarket,byProduct)
  }
  return(toReturn)
}

calculateMarginalRevenuesAndCounterfactualShares <- function(market.data,transformed.draws,pars,linear.model) {
  # Recompute one market's shares from xi and the linear demand model, then
  # derive per-product markups, loan sizes and consumer surplus.
  #
  # Arguments:
  #   market.data       -- product rows of a single market; must carry xi and RATE_INST
  #   transformed.draws -- simulated consumers from transformRawDraws()
  #   pars              -- parameter list (LTV_LIMIT; N_SIM via appendDrawsToData)
  #   linear.model      -- fitted linear demand model; its 'RATE_INST' coefficient
  #                        is the mean of alpha
  #
  # Returns a data.table with one row per product j: MARKET_ID, delta,
  # SHARE_MODEL, MARKUP_MODEL, LOAN_SIZE, SURPLUS, SURPLUS_HIGH, SURPLUS_LOW.
  # The SURPLUS* columns hold market-level values, repeated on every row.
  #
  # Side effect: adds/overwrites `delta` in market.data by reference.
  
  
  # Calculate the counterfactual delta
  ## Key thing---if we change rate, do it through "RATE_INST" not "RATE"
  market.data[,delta := xi + predict(linear.model,market.data)]
  
  ## Expand market data with draws: one row per (product, simulated consumer i)
  md <- appendDrawsToData(market.data,transformed.draws,pars)
  
  nprods = nrow(market.data)
  
  ## Calculate utilities and market shares
  # Loan cap: conforming ('c') products are bounded by both the LTV limit and the
  # conforming loan limit; jumbo ('j') products only by the LTV limit.
  md[,constrained.loan.size := pmin(pars$LTV_LIMIT * exp(log_house_price),market.conforming_loan_limit) * (JUMBO == 'c') + pars$LTV_LIMIT * exp(log_house_price) * (JUMBO == 'j') ] # Maximum loan size the consumer can get with this product
  md[,actual.loan.size      := pmin(exp(log_f),constrained.loan.size)]
  md[,numerator.per.lender  := exp(delta + alpha.deviation * RATE_INST + gamma.deviation * (JUMBO == 'c') - beta * (exp(log_f) > constrained.loan.size) - 999 * (actual.loan.size < market.conforming_loan_limit) * (JUMBO == 'j'))] # the -999 penalty means a jumbo product whose loan is below the conforming limit is effectively never chosen
  md[,numerator             := N_LENDERS * numerator.per.lender ]
  # Summed over consumer i's nprods product rows, 4/nprods adds up to an
  # outside-option weight of 4 whatever the number of products (assuming each i
  # appears once per product row, as appendDrawsToData builds it).
  md[,denominator           := sum(1 * (4/nprods) + numerator,na.rm=T),by='i']
  md[,share                 := numerator / denominator]
  
  ## Need shares and marginal revenues
  
  # For bank lenders, markup depends also on share in other product...
  # other.share / other.lenders: the same consumer's share and lender count for
  # the other products of the same LENDER type (other.lenders floored at 1).
  md[,total.share := sum(share,na.rm=T),by=c('i','LENDER')]
  md[,N_LENDERS_TOTAL := sum(N_LENDERS,na.rm=T),by=c('i','LENDER')]
  md[,other.share := total.share - share]
  md[,other.lenders := N_LENDERS_TOTAL - N_LENDERS]
  md[other.lenders == 0,other.lenders := 1]
  
  # First, need actual alpha (not just the alpha deviation.....)
  # In our setting, alpha is negative of the elasticity (e.g., a negative alpha means you dislike high rates.)
  md[,alpha.total           := alpha.deviation + linear.model$coefficients['RATE_INST']] # Add in the total alpha... Expect alpha.total to be negative.
  # Markup = 1 / (-alpha.total) / (1 - own per-lender share - other-product per-lender share)
  md[,eqm.markup            := 1 / (-1 * alpha.total) * 1 / (1 - share/N_LENDERS - other.share/other.lenders)]
  #md[,dsdr                  := alpha.total * (1 - share/N_LENDERS) * share/N_LENDERS]
  # Log-sum surplus -(log(denominator) + 0.5772) / alpha.total (0.5772 ~ Euler-Mascheroni
  # constant), scaled by 1/100 and the annuity factor (1 - 1.05^-10) / 0.05.
  md[,cons.surplus          := -1 * (log(denominator)+.5772) * (1/alpha.total) * 1/100 * ( (1- (1+0.05)^(-10) ) / (0.05) )]
  
  # Get high-low income for consumer surplus.
  # Buckets: below the 25th / above the 75th percentile of log_income across md rows.
  md[,income.bucket := 'middle']
  md[log_income < quantile(md$log_income,.25),income.bucket := 'low']
  md[log_income > quantile(md$log_income,.75),income.bucket := 'high']
  
  
  byProduct <- md[,j=list(MARKET_ID = MARKET_ID[1],delta = delta[1],SHARE_MODEL = mean(share,na.rm=T),
                          #DSDR_MODEL = mean(dsdr,na.rm=T),
                          MARKUP_MODEL = mean(eqm.markup),
                          LOAN_SIZE = weighted.mean(actual.loan.size,w=share,na.rm=T),SURPLUS = mean(cons.surplus,na.rm=T),
                          SURPLUS_HIGH = weighted.mean(cons.surplus,w=(income.bucket == 'high')),
                          SURPLUS_LOW  = weighted.mean(cons.surplus,w=(income.bucket == 'low'))),by=c('j')]
  
  # Surplus is at a market level...: copy the first product's values to every row
  byProduct[,SURPLUS := byProduct$SURPLUS[1]]
  byProduct[,SURPLUS_HIGH := byProduct$SURPLUS_HIGH[1]]
  byProduct[,SURPLUS_LOW := byProduct$SURPLUS_LOW[1]]
  
  
  toReturn <- byProduct
  
  return(toReturn)
}




appendDrawsToData <- function(market.data,transformed.draws,pars) {
  # Pair every product row of one market with every simulated consumer.
  #
  # Each product row is repeated pars$N_SIM times in a block (rep(each = )), and
  # the draw table is stacked once per product (rep(times = )). Row k therefore
  # pairs product ceiling(k / N_SIM) with draw ((k - 1) %% N_SIM) + 1, assuming
  # transformed.draws has exactly pars$N_SIM rows (TODO confirm at callers).
  # rep() on these tables presumably dispatches to the data.frame method that
  # mefa provides.
  #
  # Returns a data.table with nrow(market.data) * N_SIM rows.
  n.products <- nrow(market.data)
  products.expanded <- rep(market.data, each = pars$N_SIM)
  draws.stacked <- rep(transformed.draws, times = n.products)
  data.table(cbind(products.expanded, draws.stacked))
}

transformRawDraws <- function(raw.draws,market.data,pars) {
  # Map the standard-normal raw draws onto one market's simulated consumers.
  #
  # Arguments:
  #   raw.draws   -- list with demo.draws (N_SIM x 2) and pref.draws (N_SIM x 4),
  #                  as built by createRawDraws()
  #   market.data -- rows of a single market; only the first row is read
  #   pars        -- parameter list providing N_SIM, A, B, S and LTV_LIMIT
  #
  # Returns a data.table with one row per simulated consumer i. Columns:
  # log_income, log_house_price, the preference draws (alpha, beta, gamma,
  # log_f), their '.deviation' versions (alpha and gamma means set to 0), and
  # constrained.loan.size / actual.loan.size.
  mkt <- market.data[1]
  n <- pars$N_SIM
  row.ids <- 1:n
  demo.cols <- c('log_income','log_house_price')

  # N_SIM x 2 matrix whose two columns are constant at the given values
  constantDemoMatrix <- function(income.value, price.value) {
    matrix(data = c(rep(income.value, n), rep(price.value, n)),
           nrow = n, ncol = 2, dimnames = list(row.ids, demo.cols))
  }

  ## 1. Demographics: market means plus shocks correlated through the Cholesky
  ##    factor of the market's 2 x 2 covariance matrix.
  sd.income <- mkt$market.log_income.std
  sd.price  <- mkt$market.log_price.std
  covar     <- sd.income * sd.price * mkt$market.income_price.corr
  cov.matrix.demo <- matrix(data = c(sd.income^2, covar, covar, sd.price^2),
                            nrow = 2, ncol = 2, dimnames = list(demo.cols, demo.cols))
  demo.transform <- constantDemoMatrix(mkt$market.log_income.mean, mkt$market.log_price.mean) +
    raw.draws$demo.draws %*% chol(cov.matrix.demo)

  ## 2. Preference draws: mean parameters A, plus demeaned demographics loaded
  ##    through B, plus raw shocks scaled by chol(S).
  D_demeaned  <- demo.transform - constantDemoMatrix(mkt$market.log_income.mean.BAR, mkt$market.log_price.mean.BAR)
  demo.shift  <- D_demeaned %*% t(pars$B)
  taste.shock <- raw.draws$pref.draws %*% chol(pars$S)
  mean.rep    <- do.call("rbind", rep(list(t(pars$A)), n))
  pref.transform <- demo.shift + mean.rep + taste.shock

  ## 3. Deviation draws used in the estimation: identical, except that the alpha
  ##    and gamma means are zeroed because those averages are identified in the
  ##    regression.
  mean.rep.deviation <- mean.rep
  mean.rep.deviation[,'alpha'] <- 0
  mean.rep.deviation[,'gamma'] <- 0
  pref.transform.deviation <- demo.shift + mean.rep.deviation + taste.shock
  colnames(pref.transform.deviation) <- paste0(colnames(pref.transform.deviation),'.deviation')

  ## Assemble one row per simulated person
  transformed.draws <- data.table(i = row.ids, demo.transform, pref.transform, pref.transform.deviation)

  # Largest loan each consumer could get if the product were jumbo (LTV cap only)
  transformed.draws[,constrained.loan.size := pars$LTV_LIMIT * exp(log_house_price)]
  transformed.draws[,actual.loan.size      := pmin(exp(log_f),constrained.loan.size)]

  # If no simulated consumer reaches conforming limit + 5000, move the consumer(s)
  # holding the largest loan up to exactly that amount (house price and log_f).
  target.loan <- mkt$market.conforming_loan_limit[1] + 5000
  max.loan.size <- max(transformed.draws$actual.loan.size)
  if (max.loan.size < target.loan) {
    at.max <- transformed.draws$actual.loan.size == max.loan.size
    transformed.draws[at.max, log_house_price := log(target.loan / pars$LTV_LIMIT)]
    transformed.draws[at.max, log_f := log(target.loan)]
    transformed.draws[at.max, log_f.deviation := log(target.loan)]
  }

  return(transformed.draws)
}

createRawDraws <- function(pars) {
  # Draw the standard-normal shocks that drive the simulation.
  #
  # Seeds the global RNG with pars$SEED, then draws, in this order:
  #   demo.draws -- N_SIM x 2 matrix (columns log_income, log_house_price)
  #   pref.draws -- N_SIM x 4 matrix (columns alpha, beta, gamma, log_f)
  # Rows are named 1..N_SIM. Returns list(demo.draws, pref.draws, pars).
  set.seed(pars$SEED)

  n.sim <- pars$N_SIM
  row.ids <- 1:n.sim
  drawMatrix <- function(cols) {
    matrix(data = rnorm(n = length(cols) * n.sim), nrow = n.sim, ncol = length(cols),
           dimnames = list(row.ids, cols))
  }

  # Draw order matters for reproducibility: demographics first, then preferences.
  demo.draws <- drawMatrix(c('log_income','log_house_price'))
  pref.draws <- drawMatrix(c('alpha','beta','gamma','log_f'))

  list(demo.draws = demo.draws, pref.draws = pref.draws, pars = pars)
}


weighted.std <- function(x,w) {
  # Weighted standard deviation of x with non-negative (unnormalised) weights w.
  #
  # Computed in two passes as sqrt(E_w[(x - E_w[x])^2]) instead of
  # sqrt(E_w[x^2] - E_w[x]^2). The one-pass form can come out slightly negative
  # through floating-point cancellation (e.g. nearly constant x), and sqrt()
  # then returns NaN with a warning. Elements whose x is NA are dropped. The
  # value is returned visibly.
  m <- weighted.mean(x, w, na.rm = TRUE)
  sqrt(weighted.mean((x - m)^2, w, na.rm = TRUE))
}


################# SUPPLY SIDE ################# 

getExcessCapitalization <- function(capital_requirement = NULL) {
  # Summarise bank capitalization by market cell.
  #
  # Reads '../data/final/capitalization_data_sep9.csv' (relative to the working
  # directory). For each YEAR x CBSA x PURPOSE x JUMBO cell it computes:
  #   N_ACTIVE_BASELINE -- number of rows with excess capital ratio
  #                        BANK_ECR = BANK_CR - rho_hat strictly above 0
  #   PCT_HELD_DATA     -- N_ORIG-weighted mean of PCT_HELD (NAs dropped)
  # The capital_requirement argument is accepted but currently unused.
  #
  # Returns list(baseline_cr = <data.table of the cell aggregates>).
  cap.data <- fread('../data/final/capitalization_data_sep9.csv')
  cap.data[, BANK_ECR := BANK_CR - rho_hat]

  cell.summary <- cap.data[, list(N_ACTIVE_BASELINE = sum(BANK_ECR > 0, na.rm = TRUE),
                                  PCT_HELD_DATA     = weighted.mean(PCT_HELD, w = N_ORIG, na.rm = TRUE)),
                           by = c('YEAR','CBSA','PURPOSE','JUMBO')]

  list(baseline_cr = cell.summary)
}

loadEstimation <- function(estimation_config = 'estimation_5',override_A = NULL,override_B = NULL, override_S = NULL,override_linear_demand = NULL, cores = 10) {
  # Reload a finished demand estimation and recompute shares, markups and surplus.
  #
  # Arguments:
  #   estimation_config      -- file name (without '.r') in estimation_configs/. It is
  #                             source()d into the global environment and must define
  #                             pars$filename, the basename of the .rsave file in
  #                             estimation_results/ to load.
  #   override_A/B/S         -- optional replacement structural matrices; each one
  #                             replaces only its own entry of pars when non-NULL
  #   override_linear_demand -- optional replacement for the linear demand model
  #   cores                  -- cores passed to getAllSharesAndMR() (default 10,
  #                             the previously hard-coded value)
  #
  # Returns list(pars, market.data, linear.model, moments, mr, solution, raw.draws).

  # Load result. source() evaluates in the global environment, so the `pars`
  # read here is found there. load() is expected to define `estimation_data` in
  # this function's frame.
  source(paste0('estimation_configs/',estimation_config,'.r'))
  load(paste0('estimation_results/',pars$filename,'.rsave'))

  # 0. Get estimation pars
  pars <- estimation_data$pars

  # 1. Replace the structural matrices with the estimated solution
  new_mx <- param_vec_to_matrix(estimation_data$solution$x)
  pars$A <- new_mx$A
  pars$B <- new_mx$B
  pars$S <- new_mx$S

  linear <- estimation_data$solution$linear.model

  # Optional overrides (e.g. randomized parameters). Each is applied only when
  # supplied. Assigning NULL into a list deletes the element, so a partial
  # override must not touch the other entries.
  if (!is.null(override_A)) {
    pars$A <- override_A
  }
  if (!is.null(override_B)) {
    pars$B <- override_B
  }
  if (!is.null(override_S)) {
    pars$S <- override_S
  }
  if (!is.null(override_linear_demand)) {
    linear <- override_linear_demand
  }

  # 2. Recover necessary data and recompute shares / markups / surplus
  market.data.merged <- estimation_data$solution$market.data.merged
  moments <- estimation_data$solution$moments
  marginal.revenues <- getAllSharesAndMR(markets.data = market.data.merged,raw.draws = estimation_data$raw.draws,pars = pars,linear.model = linear,cores = cores)

  result <- list(pars = pars,market.data = market.data.merged,linear.model = linear,moments = moments,mr = marginal.revenues,solution = estimation_data$solution$x,raw.draws = estimation_data$raw.draws)
  return(result)
}


loadSupplyEstimationForCF <- function(estimation_config = 'estimation_5',override_linear_year = NULL, override_linear_supply = NULL,override_nonlinear_supply = NULL) {
  # Assemble the marginal-cost components used in counterfactuals.
  #
  # Loads supply_estimation_results/<estimation_config>.rsave, which is expected
  # to define `result`, `table.A` and `table.D` in this frame (TODO confirm
  # against the script that writes it). When override_linear_year is supplied,
  # the three overrides are used instead:
  #   override_linear_year      -- table with YEAR and MC_LIN_YEAR columns
  #   override_linear_supply    -- fitted linear MC model; coefficients are read by name
  #   override_nonlinear_supply -- table with `parameter` / `value` columns
  #
  # Returns a data.table crossing the year effects, FICO x JUMBO effects,
  # LENDER x PURPOSE effects and the non-linear parameters (MC_NONLIN_GSE,
  # MC_NONLIN_RE, MC_NONLIN_PSI). Non-bank lenders are kept only for conforming
  # ('c') products.

  load(paste0('supply_estimation_results/',estimation_config,'.rsave'))

  if (!is.null(override_linear_year)) {
    linear_year = override_linear_year
    linear = override_linear_supply
    nonlinear = override_nonlinear_supply
    mc.year <- linear_year[,c('YEAR','MC_LIN_YEAR')]
  } else {
    linear = result$mc$linear
    nonlinear = table.D
    mc.year <- table.A[,c('year','value')]
    names(mc.year) <- c('YEAR','MC_LIN_YEAR')
  }

  # Risk / jumbo effects relative to the baseline cell FICO 'h', JUMBO 'c'
  mc.risk.jumbo <- data.table(JUMBO = c('c','c','j','j'),FICO = c('l','h','l','h'),
                              MC_LIN_RISK_JUMBO = as.numeric(c(linear$coefficients['FICOl'],
                                         0 ,
                                         linear$coefficients['FICOl'] + linear$coefficients['JUMBOj'] + linear$coefficients['FICOl:JUMBOj'] ,
                                         linear$coefficients['JUMBOj'] )))

  # Lender / purpose effects relative to the baseline cell LENDER 'b', PURPOSE 'p'
  mc.labor <- data.table(LENDER = c('b','b','n','n','f','f'),PURPOSE = c('p','r','p','r','p','r'),
                         MC_LIN_LABOR = as.numeric(c(0 ,
                                    linear$coefficients['PURPOSEr']  ,
                                    linear$coefficients['LENDERn']  ,
                                    linear$coefficients['LENDERn'] + linear$coefficients['PURPOSEr'] + linear$coefficients['PURPOSEr:LENDERn'] ,
                                    linear$coefficients['LENDERf']  ,
                                    linear$coefficients['LENDERf'] + linear$coefficients['PURPOSEr'] + linear$coefficients['PURPOSEr:LENDERf'])))

  mc.nonlinear <- data.table(MC_NONLIN_GSE = as.numeric(nonlinear[parameter == 'sig_gse']$value),
                             MC_NONLIN_RE = as.numeric(nonlinear[parameter == 'sig_re']$value),
                             MC_NONLIN_PSI = as.numeric(nonlinear[parameter == 'psi']$value))

  # Cartesian product of all components (merge with by = NULL is a cross join)
  mc.results <- merge(data.frame(mc.year),data.frame(mc.risk.jumbo),by=NULL)
  mc.results <- merge(mc.results,data.frame(mc.labor),by=NULL)
  mc.results <- merge(mc.results,data.frame(mc.nonlinear),by=NULL)
  mc.results <- data.table(mc.results)
  mc.results <- mc.results[LENDER == 'b' | (LENDER != 'b' & JUMBO == 'c')]

  return(mc.results)
}


getBankNonlinMarginalCostsForSupplyEstimation <- function(data.in,pars) {
  
  # does non-linear calculations at bank level before aggregating.
  # Computes the non-linear marginal cost bank row by bank row, then aggregates
  # to YEAR x CBSA x PURPOSE x JUMBO cells.
  #
  # Arguments:
  #   data.in -- bank-level data.table; must contain BANK_CR, rho_hat, xi_j, xi_g,
  #              JUMBO, N_ORIG and the grouping columns
  #   pars    -- list providing sig_eq, psi and sig_gse
  #
  # Returns one row per cell: LENDER = 'b', ACTIVE (number of active rows), and
  # ECAP_BANK, BANK_CR, MC_NONLIN and HELD as N_ORIG-weighted means over active
  # rows.
  #
  # Side effect: adds ECAP_BANK, ACTIVE, mc_jumbo_nonlin_balance,
  # mc_conforming_nonlin_balance, mc_conforming_nonlin_gse, mc_nonlin and held
  # to data.in by reference.
  
  
  # Bank-level costs 
  # ACTIVE is based on the raw excess capital ratio. ECAP_BANK is then floored
  # at 0.001 so the negative power in the cost terms below stays well defined.
  data.in[,ECAP_BANK := BANK_CR - rho_hat]
  data.in[,ACTIVE   := ECAP_BANK > 0.001]
  data.in[,ECAP_BANK := pmax(0.001,BANK_CR - rho_hat)]
  
  # Balance-sheet cost: BANK_CR^2 * sig_eq * psi * xi * ECAP_BANK^-(psi + 1),
  # with xi_j for jumbo and xi_g for conforming; the GSE alternative costs sig_gse.
  data.in[,mc_jumbo_nonlin_balance      := BANK_CR^2 * pars$sig_eq * pars$psi * xi_j * (ECAP_BANK)^(-1 * (pars$psi + 1))]
  data.in[,mc_conforming_nonlin_balance := BANK_CR^2 * pars$sig_eq * pars$psi * xi_g * (ECAP_BANK)^(-1 * (pars$psi + 1))]
  data.in[,mc_conforming_nonlin_gse          := pars$sig_gse]
  # Jumbo uses the balance-sheet cost; conforming takes the cheaper of balance sheet and GSE.
  data.in[JUMBO == 'j',mc_nonlin := mc_jumbo_nonlin_balance]
  data.in[JUMBO == 'c',mc_nonlin := pmin(mc_conforming_nonlin_balance,mc_conforming_nonlin_gse)]
  # held = 1 when the balance-sheet cost is the one applied (always for jumbo;
  # for conforming only when it is below the GSE cost).
  data.in[JUMBO == 'j',held := 1]
  data.in[JUMBO == 'c',held := 0]
  data.in[JUMBO == 'c' & mc_conforming_nonlin_balance < mc_conforming_nonlin_gse,held := 1]
  
  # Aggregate to YEAR-CBSA-PURPOSE-JUMBO level
  # Inside list(), ACTIVE in the weights is the row-level column, not the
  # ACTIVE = sum(ACTIVE) computed alongside it.
  result <- data.in[,j=list(LENDER = 'b',ACTIVE = sum(ACTIVE),ECAP_BANK = weighted.mean(ECAP_BANK,w = ACTIVE * N_ORIG),BANK_CR = weighted.mean(BANK_CR,w=ACTIVE * N_ORIG),
                            MC_NONLIN = weighted.mean(mc_nonlin,w=ACTIVE * N_ORIG,na.rm=T),HELD = weighted.mean(held,w = ACTIVE * N_ORIG)),by=c('YEAR','CBSA','PURPOSE','JUMBO')]
  
  return(result)
}



calculateMC <- function(estimation.data,bank.data,supply.pars) {
  # Split implied marginal cost into a linear part (regression on observables)
  # and a non-linear capital-cost part, then compute the supply-side moments.
  #
  # Model: MC = MC_LIN + MC_NONLIN, with equilibrium pricing R - MARKUP = MC.
  #
  # Arguments:
  #   estimation.data -- product-level data.table (RATE, MARKUP_MODEL, YEAR, CBSA,
  #                      PURPOSE, JUMBO, LENDER, FICO, N_LOANS, N_ACTIVE_BASELINE, ...)
  #   bank.data       -- bank-level data.table for getBankNonlinMarginalCostsForSupplyEstimation()
  #   supply.pars     -- non-linear supply parameters (sig_gse, sig_eq, psi, ...)
  #
  # Returns list(params, moments, linear). The `moments` labels 'r2' and
  # 'held_mean' each appear once per year, 2017 down to 2010, with per-year scale
  # factors.
  #
  # NOTE: adds MC to estimation.data by reference. felm() comes from the lfe
  # package, which this file does not attach itself.

  # Marginal cost implied by observed rates and modeled markups
  estimation.data[,MC := RATE - MARKUP_MODEL]

  # Bank non-linear costs aggregated to market cells, merged onto the products
  bank.costs <- getBankNonlinMarginalCostsForSupplyEstimation(bank.data,supply.pars)
  merged <- merge(estimation.data,bank.costs,by=c('YEAR','CBSA','PURPOSE','JUMBO','LENDER'),all.x=TRUE)

  # Non-bank lenders: treated as always held, at the GSE cost sig_gse
  merged[LENDER != 'b',HELD := 1]
  merged[LENDER != 'b',MC_NONLIN := supply.pars$sig_gse]
  merged[LENDER != 'b',ACTIVE := N_ACTIVE_BASELINE]

  # Linear residual cost and its regression; 2017 loans get ten times the weight
  merged[,mc_lin := MC - MC_NONLIN]
  merged[,weight := N_LOANS * (1e1 * (YEAR == 2017) + 1e0 * (YEAR != 2017))]
  has.several.years <- length(unique(merged$YEAR)) > 1
  if (has.several.years) {
    linear.model <- lm(mc_lin ~ as.factor(YEAR) + FICO*JUMBO + PURPOSE * LENDER,data=merged,weights = merged$weight)
  } else {
    linear.model <- lm(mc_lin ~ FICO*JUMBO + PURPOSE * LENDER,data=merged,weights = merged$weight)
  }

  # Residual and total predicted marginal cost
  linear.fit <- predict(linear.model,merged)
  merged[,mc_residual := mc_lin - linear.fit]
  merged[,pred.mc := MC_NONLIN + linear.fit]
  md.c <- merged[LENDER == 'b' & JUMBO == 'c']
  md.j <- merged[LENDER == 'b' & JUMBO == 'j']

  # Regress MC on CR (data vs. model) and observed on predicted MC
  rel.data <- felm(MC ~ BANK_CR,data=md.j)
  rel.mode <- felm(pred.mc ~ BANK_CR,data=md.j)
  overall  <- felm(MC ~ pred.mc,data=merged)

  # One moment row per year (2017 first): scale * loan-weighted mean of `values`.
  moment.years <- c(2017, 2016, 2015, 2014, 2013, 2012, 2011, 2010)
  perYearMoments <- function(label, values, n.loans, year, scales) {
    rows <- lapply(seq_along(moment.years), function(k) {
      data.table(moment = label,
                 value  = scales[k] * weighted.mean(values, na.rm = TRUE, w = n.loans * (year == moment.years[k])))
    })
    do.call(rbind, rows)
  }

  moments <- rbind(
    data.table(moment = 'p_slope',value= 1e1* (overall$beta['pred.mc','MC']-1)),
    data.table(moment = 'cr_slope',value=1e1* (rel.data$beta['BANK_CR','MC'] - rel.mode$beta['BANK_CR','pred.mc'])),
    perYearMoments('r2', merged$mc_residual^2, merged$N_LOANS, merged$YEAR,
                   c(1e1, 1e1, 1e1, 1e0, 1e0, 1e0, 1e0, 1e0)),
    data.table(moment = 'r_cap_cov', value = 1e2*cov(merged$mc_residual,merged$ECAP_BANK,use = 'complete')),
    data.table(moment = 'held_cap_cov', value = cov(md.c$PCT_HELD_DATA,  md.c$ECAP_BANK,use = 'complete') - cov(md.c$HELD,md.c$ECAP_BANK,use = 'complete')),
    data.table(moment = 'held_cap_var', value = 1e0*var(md.c$PCT_HELD_DATA - md.c$HELD,na.rm=TRUE)),
    perYearMoments('held_mean', md.c$PCT_HELD_DATA - md.c$HELD, md.c$N_LOANS, md.c$YEAR,
                   c(1e2, 1e1, 1e1, 1e0, 1e0, 1e0, 1e0, 1e0))
  )

  return(list(params = supply.pars,moments = moments,linear = linear.model))
}




loadSupplyEstimationData <- function(estimation = 'estimation_5',demand.in = NULL) {
  # Combine the demand-estimation output with bank capitalization aggregates.
  #
  # Arguments:
  #   estimation -- demand-estimation config name passed to loadEstimation()
  #   demand.in  -- optional, already-computed loadEstimation() result; loaded from
  #                 `estimation` only when NULL
  #
  # Returns a data.table with one row per product (MARKET_ID, j) carrying
  # MARKUP_MODEL, the product descriptors and the cell's N_ACTIVE_BASELINE /
  # PCT_HELD_DATA.

  # Load the demand estimation only when the caller did not supply one.
  # (A supplied demand.in was previously discarded and reloaded, and
  # demand.in = NULL failed at demand.in$mr below.)
  if (is.null(demand.in)) {
    demand.in <- loadEstimation(estimation)
  }

  # Get the markets so we can merge in the capitalization information
  dm <- merge(demand.in$mr[,c('MARKET_ID','j','MARKUP_MODEL')],demand.in$market.data[,c('MARKET_ID','j','YEAR','CBSA','PURPOSE','JUMBO','LENDER','RATE','FICO','RATE_INST','N_LOANS')],by=c('MARKET_ID','j'))

  # Capitalization in
  cap.in <- getExcessCapitalization()

  # Inner join on YEAR/CBSA/PURPOSE/JUMBO. LENDER is not a key, so every lender
  # type in a cell receives that cell's bank capitalization aggregates.
  merged <- merge(dm,cap.in$baseline_cr,by=c('YEAR','CBSA','PURPOSE','JUMBO'))

  return(merged)
}



